from fastmcp import FastMCP, Context
from typing import List, Dict, Optional, Literal, Any
import json
import re
import subprocess
import tempfile
import os
import shutil

# FastMCP server; every function below is exposed through its @mcp.tool decorator.
mcp = FastMCP(
    "DiagramMCP - Dynamic diagram generation for software development"
)

# Closed set of diagram kinds accepted by validate_diagram's `diagram_type`.
DiagramType = Literal[
    "flowchart",
    "sequence",
    "class",
    "er",
    "gantt",
    "gitgraph",
    "mindmap",
    "architecture",
    "network",
]

@mcp.tool
def create_flowchart(
    title: str,
    nodes: List[Dict[str, str]],
    connections: List[Dict[str, str]],
    direction: str = "TD",
    theme: str = "default"
) -> str:
    """
    Create a Mermaid flowchart diagram.

    Args:
        title: Title of the flowchart. Titles that YAML could misread
            (for example ones containing ': ' or ' #') are written as
            quoted strings in the front matter.
        nodes: List of nodes with 'id', 'label', and optional 'shape':
            rectangle (default), rounded, stadium, subroutine, cylinder,
            circle, diamond, hexagon or parallelogram. Unknown shapes are
            drawn as rectangles.
        connections: List of connections with 'from', 'to', and optional 'label'
        direction: Flow direction (TD, TB, LR, BT, RL); case-insensitive
        theme: Theme name (default, dark, forest, neutral, base). The
            non-default built-in themes are written to the front-matter
            'config' block; any other value keeps Mermaid's default theme.

    Returns:
        Mermaid flowchart syntax

    Raises:
        ValueError: If a node lacks 'id'/'label', a connection lacks
            'from'/'to', or the direction is not recognised.
    """
    flow_direction = direction.strip().upper()
    if flow_direction not in ("TD", "TB", "BT", "RL", "LR"):
        raise ValueError(
            f"Invalid direction {direction!r}; expected one of TD, TB, BT, RL, LR"
        )

    for node in nodes:
        if 'id' not in node or 'label' not in node:
            raise ValueError("Each node must have 'id' and 'label' fields")
    for conn in connections:
        if 'from' not in conn or 'to' not in conn:
            raise ValueError("Each connection must have 'from' and 'to' fields")

    # Opening/closing delimiters of each supported Mermaid node shape.
    shape_delimiters = {
        'rectangle': ('[', ']'),
        'rounded': ('(', ')'),
        'stadium': ('([', '])'),
        'subroutine': ('[[', ']]'),
        'cylinder': ('[(', ')]'),
        'circle': ('((', '))'),
        'diamond': ('{', '}'),
        'hexagon': ('{{', '}}'),
        'parallelogram': ('[/', '/]'),
    }

    def mermaid_text(value: Any) -> str:
        """Return node/edge text, double-quoted when it contains syntax characters."""
        text = str(value)
        # Brackets, pipes, quotes, '#' and ';' would otherwise be parsed as
        # Mermaid syntax; inside a quoted string they are literal, and an
        # embedded double quote is written as the #quot; entity code.
        if text and not re.search(r'[\[\](){}<>|"#;]', text):
            return text
        return '"' + text.replace('"', '#quot;') + '"'

    # Front-matter titles are YAML: keep simple titles verbatim and emit any
    # other title as a JSON string, which is also a valid YAML double-quoted
    # scalar (a bare "Auth: v2" or "Build #3" would be misparsed).
    if re.fullmatch(r"\w[\w .,()/+-]*", title):
        title_yaml = title
    else:
        title_yaml = json.dumps(title, ensure_ascii=False)

    lines = ["---", f"title: {title_yaml}"]
    theme_name = theme.strip().lower()
    if theme_name in ("base", "dark", "forest", "neutral"):
        lines += ["config:", f"  theme: {theme_name}"]
    lines += ["---", f"flowchart {flow_direction}"]

    for node in nodes:
        opener, closer = shape_delimiters.get(
            node.get('shape', 'rectangle'), shape_delimiters['rectangle']
        )
        lines.append(f"    {node['id']}{opener}{mermaid_text(node['label'])}{closer}")

    for conn in connections:
        label = conn.get('label', '')
        arrow = f"-->|{mermaid_text(label)}|" if label else "-->"
        lines.append(f"    {conn['from']} {arrow} {conn['to']}")

    return "\n".join(lines) + "\n"

@mcp.tool
def create_sequence_diagram(
    title: str,
    participants: List[str],
    interactions: List[Dict[str, str]],
    theme: str = "default"
) -> str:
    """
    Create a Mermaid sequence diagram.

    Args:
        title: Title of the sequence diagram. Titles that YAML could misread
            (for example ones containing ': ' or ' #') are written as quoted
            strings in the front matter.
        participants: List of participant names, declared in the given order
        interactions: List of interactions with 'from', 'to', 'message', and
            optional 'type': 'arrow' (default, emitted as '->>'), 'dotted'
            ('-->'), 'activation' ('->>+') or 'deactivation' ('->>-').
            Unknown types are emitted as '->>'.
        theme: Theme name (default, dark, forest, neutral, base). The
            non-default built-in themes are written to the front-matter
            'config' block; any other value keeps Mermaid's default theme.

    Returns:
        Mermaid sequence diagram syntax
    """
    arrows = {'dotted': '-->', 'activation': '->>+', 'deactivation': '->>-'}

    def message_text(value: Any) -> str:
        """Return message text with line breaks and special characters encoded."""
        text = re.sub(r"\r?\n", "<br/>", str(value))
        # Mermaid's sequence syntax treats a bare '#' (comment) and ';'
        # (statement separator) specially inside message text, so encode them
        # as entity codes. Entity codes already present (e.g. '#9829;') are kept.
        return re.sub(
            r"#\w+;|[#;]",
            lambda m: m.group(0) if len(m.group(0)) > 1
            else ("#35;" if m.group(0) == "#" else "#59;"),
            text,
        )

    # Front-matter titles are YAML: keep simple titles verbatim and emit any
    # other title as a JSON string (also a valid YAML double-quoted scalar).
    if re.fullmatch(r"\w[\w .,()/+-]*", title):
        title_yaml = title
    else:
        title_yaml = json.dumps(title, ensure_ascii=False)

    lines = ["---", f"title: {title_yaml}"]
    theme_name = theme.strip().lower()
    if theme_name in ("base", "dark", "forest", "neutral"):
        lines += ["config:", f"  theme: {theme_name}"]
    lines += ["---", "sequenceDiagram"]

    lines.extend(f"    participant {participant}" for participant in participants)
    lines.append("")  # blank separator between declarations and messages

    for interaction in interactions:
        source = interaction['from']
        target = interaction['to']
        message = message_text(interaction['message'])
        arrow = arrows.get(interaction.get('type', 'arrow'), '->>')
        lines.append(f"    {source}{arrow}{target}: {message}")

    return "\n".join(lines) + "\n"

@mcp.tool
def create_class_diagram(
    title: str,
    classes: List[Dict[str, Any]],
    relationships: List[Dict[str, str]]
) -> str:
    """
    Create a Mermaid class diagram.

    Args:
        title: Title of the class diagram. Titles that YAML could misread
            (for example ones containing ': ' or ' #') are written as quoted
            strings in the front matter.
        classes: List of classes with 'name' and optional 'attributes' and
            'methods'. Attributes are dicts with 'name' and optional 'type'
            (default 'String') and 'visibility' (default '+'). Methods are
            dicts with 'name' and optional 'parameters' (dicts with 'name'
            and optional 'type'), 'return_type' (default 'void') and
            'visibility' (default '+').
        relationships: List of relationships with 'from', 'to', optional
            'type' (inheritance, composition, aggregation, association,
            dependency, realization; anything else, or no type, draws a plain
            solid link) and optional 'label'.

    Returns:
        Mermaid class diagram syntax
    """
    # Mermaid relationship operator for each supported relationship type.
    arrows = {
        'inheritance': '--|>',
        'composition': '--*',
        'aggregation': '--o',
        'association': '-->',
        'dependency': '..>',
        'realization': '..|>',
    }

    # Front-matter titles are YAML: keep simple titles verbatim and emit any
    # other title as a JSON string (also a valid YAML double-quoted scalar).
    if re.fullmatch(r"\w[\w .,()/+-]*", title):
        title_yaml = title
    else:
        title_yaml = json.dumps(title, ensure_ascii=False)

    lines = ["---", f"title: {title_yaml}", "---", "classDiagram"]

    for cls in classes:
        lines.append(f"    class {cls['name']} {{")
        for attr in cls.get('attributes', []):
            visibility = attr.get('visibility', '+')
            attr_type = attr.get('type', 'String')
            lines.append(f"        {visibility}{attr_type} {attr['name']}")
        for method in cls.get('methods', []):
            visibility = method.get('visibility', '+')
            return_type = method.get('return_type', 'void')
            method_name = method['name']
            param_str = ', '.join(
                f"{p.get('type', 'String')} {p['name']}"
                for p in method.get('parameters', [])
            )
            lines.append(f"        {visibility}{method_name}({param_str}) {return_type}")
        lines.append("    }")

    for rel in relationships:
        arrow = arrows.get(rel.get('type', ''), '--')
        line = f"    {rel['from']} {arrow} {rel['to']}"
        label = rel.get('label', '')
        if label:
            line += f" : {label}"
        lines.append(line)

    return "\n".join(lines) + "\n"

@mcp.tool
def create_er_diagram(
    title: str,
    entities: List[Dict[str, Any]],
    relationships: List[Dict[str, str]]
) -> str:
    """
    Create a Mermaid Entity-Relationship diagram.

    Args:
        title: Title of the ER diagram. Titles that YAML could misread
            (for example ones containing ': ' or ' #') are written as quoted
            strings in the front matter.
        entities: List of entities with 'name' and optional 'attributes'.
            Each attribute has 'name', optional 'type' (default 'string')
            and optional 'is_key' (truthy adds a PK marker).
        relationships: List of relationships with 'from', 'to', optional
            'cardinality' (Mermaid relationship token such as '||--o{';
            default '||--||') and optional 'label'. A label that is not a
            single identifier, including the empty label, is emitted as a
            double-quoted string.

    Returns:
        Mermaid ER diagram syntax
    """
    def relationship_label(value: Any) -> str:
        """Return a relationship label that Mermaid can parse after ':'."""
        text = str(value)
        # Mermaid expects a label token after ':' - a bare identifier or a
        # double-quoted string - so empty or multi-word labels are quoted.
        # Embedded double quotes become apostrophes to keep the string closed.
        if re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", text):
            return text
        return '"' + text.replace('"', "'") + '"'

    # Front-matter titles are YAML: keep simple titles verbatim and emit any
    # other title as a JSON string (also a valid YAML double-quoted scalar).
    if re.fullmatch(r"\w[\w .,()/+-]*", title):
        title_yaml = title
    else:
        title_yaml = json.dumps(title, ensure_ascii=False)

    lines = ["---", f"title: {title_yaml}", "---", "erDiagram"]

    for entity in entities:
        lines.append(f"    {entity['name']} {{")
        for attr in entity.get('attributes', []):
            attr_name = attr['name']
            attr_type = attr.get('type', 'string')
            key_marker = " PK" if attr.get('is_key', False) else ""
            lines.append(f"        {attr_type} {attr_name}{key_marker}")
        lines.append("    }")

    for rel in relationships:
        from_entity = rel['from']
        to_entity = rel['to']
        cardinality = rel.get('cardinality', '||--||')
        label = relationship_label(rel.get('label', ''))
        lines.append(f"    {from_entity} {cardinality} {to_entity} : {label}")

    return "\n".join(lines) + "\n"

@mcp.tool
def create_gantt_chart(
    title: str,
    sections: List[Dict[str, Any]]
) -> str:
    """
    Create a Mermaid Gantt chart for project planning.

    Args:
        title: Title of the Gantt chart. It is written both to the front
            matter (quoted there when YAML could misread it) and to the
            chart's 'title' line.
        sections: List of sections with 'name' and optional 'tasks'. Each
            task has 'name' and optional 'status' (done, active, critical,
            milestone; anything else adds no tag), 'start_date' (inserted
            verbatim), 'duration' (default '1d') and 'id' (task id emitted
            when a start_date is given; defaults to the task name).

    Returns:
        Mermaid Gantt chart syntax
    """
    # Mermaid task tag emitted for each recognised status.
    status_tags = {
        'done': 'done, ',
        'active': 'active, ',
        'critical': 'crit, ',
        'milestone': 'milestone, ',
    }

    # Front-matter titles are YAML: keep simple titles verbatim and emit any
    # other title as a JSON string (also a valid YAML double-quoted scalar).
    if re.fullmatch(r"\w[\w .,()/+-]*", title):
        title_yaml = title
    else:
        title_yaml = json.dumps(title, ensure_ascii=False)

    lines = [
        "---",
        f"title: {title_yaml}",
        "---",
        "gantt",
        f"    title {title}",
        "    dateFormat  YYYY-MM-DD",
    ]

    for section in sections:
        lines.append(f"    section {section['name']}")
        for task in section.get('tasks', []):
            task_name = task['name']
            tag = status_tags.get(task.get('status', ''), '')
            start_date = task.get('start_date', '')
            duration = task.get('duration', '1d')
            if start_date:
                task_id = task.get('id', task_name)
                lines.append(f"    {task_name} :{tag}{task_id}, {start_date}, {duration}")
            else:
                lines.append(f"    {task_name} :{tag}{duration}")

    return "\n".join(lines) + "\n"

@mcp.tool
def create_architecture_diagram(
    title: str,
    components: List[Dict[str, str]],
    connections: List[Dict[str, str]],
    layers: Optional[List[str]] = None
) -> str:
    """
    Create a system architecture diagram using Mermaid.

    Args:
        title: Title of the architecture diagram. Titles that YAML could
            misread (for example ones containing ': ' or ' #') are written as
            quoted strings in the front matter.
        components: List of components with 'id', 'name', optional 'type'
            ('database' is drawn as a cylinder, 'api' as a rhombus, anything
            else such as 'service' or 'frontend' as a rectangle) and optional
            'layer'.
        connections: List of connections with 'from', 'to', and optional
            'protocol' (HTTP, TCP, etc.) used as the edge label
        layers: Optional list of layer names. Each becomes a subgraph holding
            the components whose 'layer' equals it; components whose layer is
            missing or not listed are drawn outside any subgraph.

    Returns:
        Mermaid flowchart ('graph TB') syntax describing the architecture
    """
    def quoted(value: Any) -> str:
        """Return *value* as a Mermaid double-quoted string."""
        # Mermaid strings use double quotes (single quotes would be rendered
        # literally); an embedded double quote becomes the #quot; entity code.
        return '"' + str(value).replace('"', '#quot;') + '"'

    def edge_text(value: Any) -> str:
        """Return edge-label text, quoted only when it contains syntax characters."""
        text = str(value)
        if text and not re.search(r'[\[\](){}<>|"#;]', text):
            return text
        return quoted(text)

    def node(comp: Dict[str, str]) -> str:
        """Return the Mermaid node declaration for one component."""
        comp_id = comp['id']
        name = quoted(comp['name'])
        comp_type = comp.get('type', 'service')
        if comp_type == 'database':
            return f"{comp_id}[({name})]"
        if comp_type == 'api':
            return f"{comp_id}{{{name}}}"
        return f"{comp_id}[{name}]"

    # Front-matter titles are YAML: keep simple titles verbatim and emit any
    # other title as a JSON string (also a valid YAML double-quoted scalar).
    if re.fullmatch(r"\w[\w .,()/+-]*", title):
        title_yaml = title
    else:
        title_yaml = json.dumps(title, ensure_ascii=False)

    lines = ["---", f"title: {title_yaml}", "---", "graph TB"]

    if layers:
        for layer in layers:
            lines.append(f"    subgraph {layer}")
            lines.extend(
                f"        {node(comp)}" for comp in components if comp.get('layer') == layer
            )
            lines.append("    end")
        # Declare components that belong to no listed layer too; otherwise
        # they would only appear implicitly through connections, showing
        # their id instead of their name.
        lines.extend(
            f"    {node(comp)}" for comp in components if comp.get('layer') not in layers
        )
    else:
        lines.extend(f"    {node(comp)}" for comp in components)

    for conn in connections:
        protocol = conn.get('protocol', '')
        arrow = f"-->|{edge_text(protocol)}|" if protocol else "-->"
        lines.append(f"    {conn['from']} {arrow} {conn['to']}")

    return "\n".join(lines) + "\n"

@mcp.tool
def validate_diagram(diagram_code: str, diagram_type: DiagramType) -> Dict[str, Any]:
    """
    Validate diagram syntax and provide feedback.

    This is a lightweight heuristic check, not a full Mermaid parse. It
    ignores a leading YAML front-matter block and '%%' comment/directive
    lines, then:
      * checks that the first remaining line declares the expected diagram
        kind (for flowcharts also that any direction given is valid);
      * for kinds where brackets are syntax (flowchart, architecture, class,
        er, network), checks that brackets and parentheses outside quoted
        strings, edge labels and ER cardinality markers are balanced.

    Args:
        diagram_code: The diagram code to validate
        diagram_type: Type of diagram to validate

    Returns:
        Validation result with 'valid' boolean, 'errors' list, and the
        'line_count' / 'character_count' of the submitted code.
    """
    errors: List[str] = []

    def report() -> Dict[str, Any]:
        """Build the result dict; every return path has the same keys."""
        return {
            "valid": len(errors) == 0,
            "errors": errors,
            "line_count": len(diagram_code.split('\n')),
            "character_count": len(diagram_code),
        }

    if not diagram_code.strip():
        errors.append("Diagram code is empty")
        return report()

    body = diagram_code.replace("\r\n", "\n")
    # Drop a leading front-matter block (---\n...\n---) so a title or config
    # value is never mistaken for diagram syntax.
    front_matter = re.match(
        r"\s*---[ \t]*\n(?:.*?\n)?---[ \t]*(?:\n|$)", body, flags=re.DOTALL
    )
    if front_matter:
        body = body[front_matter.end():]
    # '%%' lines are comments or init directives, not diagram syntax.
    body = re.sub(r"(?m)^[ \t]*%%.*$", "", body)

    # Declaration keywords accepted per diagram kind; the first one is named
    # in the error message.
    declarations = {
        "flowchart": ("flowchart", "graph", "flowchart-v2", "flowchart-elk"),
        "architecture": ("graph", "flowchart", "architecture-beta"),
        "sequence": ("sequenceDiagram",),
        "class": ("classDiagram", "classDiagram-v2"),
        "er": ("erDiagram",),
        "gantt": ("gantt",),
        "gitgraph": ("gitGraph",),
        "mindmap": ("mindmap",),
    }
    first_line = next((line.strip() for line in body.split("\n") if line.strip()), "")
    tokens = [tok for tok in re.split(r"[\s;]+", first_line) if tok]
    keyword = tokens[0].rstrip(":") if tokens else ""

    accepted = declarations.get(diagram_type)
    if accepted and keyword not in accepted:
        errors.append(f"Missing '{accepted[0]}' declaration")

    if diagram_type == "flowchart" and keyword in declarations["flowchart"]:
        # The direction is optional in Mermaid (defaults to top-to-bottom);
        # only a present-but-unknown direction is an error.
        if len(tokens) > 1 and tokens[1] not in ("TD", "TB", "BT", "RL", "LR"):
            errors.append("Invalid or missing flowchart direction")

    # Sequence messages, gantt tasks, git commits and mindmap shapes (e.g.
    # cloud ')text(') use brackets as free text or unbalanced syntax, so the
    # balance check would only produce false positives there.
    if diagram_type in ("flowchart", "architecture", "class", "er", "network"):
        scrubbed = re.sub(r'"[^"\n]*"', "", body)  # quoted text is literal
        if diagram_type == "er":
            # Cardinality markers such as '||--o{' or '}|..|{'.
            scrubbed = re.sub(r"[|}][o|](?:--|\.\.)[o|][|{]", "", scrubbed)
        if diagram_type in ("flowchart", "architecture"):
            scrubbed = re.sub(r"\|[^|\n]*\|", "", scrubbed)  # edge labels -->|text|
            scrubbed = re.sub(r"(?<=\w)>[^\]\n]*\]", "", scrubbed)  # asymmetric id>text]

        pairs = {"[": "]", "{": "}", "(": ")"}
        closers = set(pairs.values())
        stack: List[str] = []
        for char in scrubbed:
            if char in pairs:
                stack.append(char)
            elif char in closers:
                if not stack:
                    errors.append("Unmatched closing bracket/parenthesis")
                    break
                if pairs[stack.pop()] != char:
                    errors.append("Mismatched brackets/parentheses")
                    break
        if stack:
            errors.append("Unclosed brackets/parentheses")

    return report()

@mcp.tool
def run_mermaid_cli(
    diagram_code: str,
    output_format: Literal["png", "svg", "pdf"] = "png",
    output_file: str = "diagram",
    theme: str = "default",
    width: int = 1200,
    height: int = 800,
    background_color: str = "white",
    timeout: int = 30
) -> Dict[str, Any]:
    """
    Execute Mermaid CLI commands to generate diagrams.

    Args:
        diagram_code: The Mermaid diagram code
        output_format: Output format (png, svg, pdf)
        output_file: Output filename (without extension)
        theme: Mermaid theme (default, dark, forest, neutral)
        width: Image width in pixels
        height: Image height in pixels
        background_color: Background color for the image
        timeout: Seconds to wait for mmdc before giving up (default 30)

    Returns:
        Dictionary with command execution results and file paths. Failures
        (missing CLI, non-zero exit, timeout, other errors) are reported with
        "success": False rather than raised.
    """
    if not shutil.which("mmdc"):
        return {
            "success": False,
            "error": "Mermaid CLI (mmdc) not found",
            "message": "Please install @mermaid-js/mermaid-cli: npm install -g @mermaid-js/mermaid-cli",
            "command": "npm install -g @mermaid-js/mermaid-cli"
        }

    # NOTE(review): output_file comes straight from the tool caller and is
    # used as a path relative to the server's working directory; consider
    # restricting it to a dedicated output folder.
    output_path = f"{output_file}.{output_format}"
    input_file: Optional[str] = None

    try:
        # Write UTF-8 explicitly rather than the platform default encoding so
        # non-ASCII labels are not mangled or rejected.
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.mmd', delete=False, encoding='utf-8'
        ) as f:
            input_file = f.name  # recorded first so cleanup runs even if write fails
            f.write(diagram_code)

        cmd = [
            "mmdc",
            "-i", input_file,
            "-o", output_path,
            "-t", theme,
            "-w", str(width),
            "-H", str(height),
            "-b", background_color
        ]

        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=timeout
        )

        if result.returncode == 0:
            file_size = os.path.getsize(output_path) if os.path.exists(output_path) else 0
            return {
                "success": True,
                "output_file": output_path,
                "format": output_format,
                "theme": theme,
                "width": width,
                "height": height,
                "file_size_bytes": file_size,
                "command": " ".join(cmd),
                "message": f"Diagram generated successfully: {output_path}"
            }
        return {
            "success": False,
            "error": result.stderr,
            "stdout": result.stdout,
            "command": " ".join(cmd),
            "message": "Mermaid CLI execution failed"
        }

    except subprocess.TimeoutExpired:
        return {
            "success": False,
            "error": "Command timed out",
            "message": f"Mermaid CLI execution timed out after {timeout} seconds"
        }
    except Exception as e:
        # Tool boundary: report any other failure to the caller instead of raising.
        return {
            "success": False,
            "error": str(e),
            "message": "Failed to execute Mermaid CLI"
        }
    finally:
        # Remove the temporary input on every path (success, failure, timeout).
        # Best effort: a failed delete must not replace the result above.
        if input_file is not None:
            try:
                os.unlink(input_file)
            except OSError:
                pass

@mcp.tool
def export_diagram(
    diagram_code: str,
    format: Literal["svg", "png", "pdf", "html"] = "svg",
    theme: str = "default"
) -> Dict[str, str]:
    """
    Describe how to export a diagram to the requested format.

    Nothing is rendered here: the result echoes the inputs together with
    human-readable instructions and the tooling those instructions need.

    Args:
        diagram_code: The Mermaid diagram code
        format: Export format
        theme: Theme to apply (echoed back only)

    Returns:
        Dict with 'diagram_code', 'format', 'theme', 'instructions' and
        'tools_needed'; an unrecognised format yields empty strings for the
        last two.
    """
    # format -> (instructions, tools_needed)
    guidance = {
        "svg": (
            "Use Mermaid CLI: mmdc -i input.mmd -o output.svg",
            "@mermaid-js/mermaid-cli",
        ),
        "png": (
            "Use run_mermaid_cli tool for PNG generation",
            "@mermaid-js/mermaid-cli",
        ),
        "pdf": (
            "Use Mermaid CLI: mmdc -i input.mmd -o output.pdf",
            "@mermaid-js/mermaid-cli, puppeteer",
        ),
        "html": (
            "Embed in HTML with Mermaid.js library",
            "mermaid (CDN or npm)",
        ),
    }
    instructions, tools_needed = guidance.get(format, ("", ""))

    return {
        "diagram_code": diagram_code,
        "format": format,
        "theme": theme,
        "instructions": instructions,
        "tools_needed": tools_needed,
    }

if __name__ == "__main__":
    # Script entry point: serve the tools registered on `mcp` over stdio.
    mcp.run(transport="stdio")
